from pymilvus import Collection
from app.app_init import device, model_L14, preprocess_L14, executor, model_H14, preprocess_H14, model_B16, \
    preprocess_B16
from app.utll.embedding_util import get_embedding, image_open
from datetime import datetime


def add_infos(image_ids, image_paths, user_id, process_batch_size, model_name, metric_type):
    """
    Embed images in batches and insert the vectors into the matching vector-database collection.

    :param image_ids: sequence of image ids
    :param image_paths: sequence of image paths, aligned index-by-index with ``image_ids``
    :param user_id: id of the operating user; repeated for every inserted row
    :param process_batch_size: number of images embedded and inserted per batch (must be >= 1)
    :param model_name: model to use: "ViT-L-14", "ViT-H-14" or "ViT-B-16"
    :param metric_type: "L2" selects the collection with the ``_L2`` suffix; any other
        value (e.g. cosine) selects the model's collection without a suffix
    :return: the fixed status string "调用成功" after all batches have been inserted
    :raises ValueError: if ``process_batch_size`` is smaller than 1, if ``image_ids`` and
        ``image_paths`` differ in length, or if ``model_name`` is not supported
    """

    # Validate cheap preconditions first, before any collection or model is touched.
    # A negative batch size would otherwise produce an empty range and report
    # success without inserting anything.
    if process_batch_size < 1:
        raise ValueError(f"process_batch_size must be >= 1, got {process_batch_size!r}")
    if len(image_ids) != len(image_paths):
        raise ValueError(
            f"image_ids and image_paths must have the same length, "
            f"got {len(image_ids)} and {len(image_paths)}"
        )

    # Pick the model/preprocess pair and the tag used in the collection name.
    if model_name == "ViT-L-14":
        model_tag, model, preprocess = "L14", model_L14, preprocess_L14
    elif model_name == "ViT-H-14":
        model_tag, model, preprocess = "H14", model_H14, preprocess_H14
    elif model_name == "ViT-B-16":
        model_tag, model, preprocess = "B16", model_B16, preprocess_B16
    else:
        # Previously this fell through and failed later with an unbound local variable.
        raise ValueError(f"Unsupported model_name: {model_name!r}")

    # Only the exact value "L2" selects the L2 collection; everything else uses the default one.
    collection_name = "multimodal_search_" + model_tag
    if metric_type == "L2":
        collection_name += "_L2"

    emb_collection = Collection(collection_name)

    model_infos = {"model": model, "preprocess": preprocess, "device": device}

    for i in range(0, len(image_ids), process_batch_size):

        print(datetime.now(), " Start processing of embedding images in batch: ", i)

        # Slice out the ids and paths belonging to the current batch.
        batch_image_ids = image_ids[i:i + process_batch_size]
        batch_image_paths = image_paths[i:i + process_batch_size]

        # Open the images through the shared executor, then embed the whole batch at once.
        raw_image = list(executor.map(image_open, batch_image_paths))
        image_emb = get_embedding("image", raw_image, model_infos, batch_mode=True)

        # Column order: ids, embeddings, user id per row, paths.
        # NOTE(review): presumably this must match the collection schema's field order - confirm.
        emb_collection.insert([batch_image_ids, image_emb, [user_id] * len(batch_image_ids), batch_image_paths])
        # NOTE(review): flushing after every batch may be more frequent than needed;
        # kept as-is to preserve the existing per-batch behaviour.
        emb_collection.flush()

        # Report the index of the last image actually processed; the final batch may be
        # shorter than process_batch_size.
        print(datetime.now(), " Completed processing image embeddings in batch: ", i + len(batch_image_ids) - 1)

    emb_collection.load()

    return "调用成功"
